from environment import Plane
from RL_Core import SarsaTable
import matplotlib.pyplot as plt

# Training (update/iteration) function
def update(plane_num, iteration_num, plane_env=None, agent=None):
    """Run the SARSA training loop and return the delay recorded per episode.

    In every episode each action in ``range(1, 22)`` is tried on a copy of the
    current state. If ``plane_env.plane_change`` accepts the move, its delay is
    taken from ``plane_env.get_reward``; otherwise the delay is
    ``float("inf")``. Each result is written to the agent's table under the
    current state. The agent is then asked (``choose_action``) for the delay
    and next state from the current state, learns from that transition, and
    the loop continues from the returned next state.

    Args:
        plane_num: Number of planes. Not used inside this function; kept so
            existing callers keep working.
        iteration_num: Number of episodes to run.
        plane_env: Object providing ``start``, ``plane_change`` and
            ``get_reward``. Defaults to the module-level ``env``.
        agent: Object providing ``writing_in_sarsa_table``, ``choose_action``
            and ``learn``. Defaults to the module-level ``RL``.

    Returns:
        A list holding the ``dtime`` value returned by ``choose_action`` in
        each episode, in episode order.
    """
    # Fall back to the globals created under ``__main__`` so the original
    # two-argument call keeps working; the globals are only looked up when
    # no explicit object is supplied.
    plane_env = env if plane_env is None else plane_env
    agent = RL if agent is None else agent

    # 1. Initialise
    ob = plane_env.start()
    records = []

    # 2. Iterate
    for episode in range(iteration_num):
        print(episode)
        # 2.1 Try every action; action 0 is reserved for "keep unchanged"
        # and is therefore skipped.
        for action in range(1, 22):
            # 2.1.1 Work on a copy so the trial move cannot alter ``ob``.
            observation = ob.copy()
            if plane_env.plane_change(action, observation):
                dtime = plane_env.get_reward(observation, action)
            else:
                # The move is not allowed: record an infinite delay.
                dtime = float("inf")

            # 2.1.2 Write the delay into the table under the current state.
            agent.writing_in_sarsa_table(action, dtime, ob)

        # 2.2 Choose the best action from the *current* state. A fresh copy
        # of ``ob`` is used; the trial copy left over from the loop above
        # reflects the last trial action and is not the current state.
        state = ob.copy()
        dtime, next_observation = agent.choose_action(state)

        # 2.3 Update the table from this transition.
        # NOTE(review): ``action`` is the last value of the trial loop (21),
        # not the action picked by ``choose_action`` (which does not return
        # it) - confirm what ``learn`` expects here.
        agent.learn(state, action, dtime, next_observation)

        ob = next_observation
        print(ob)
        records.append(dtime)

    return records


if __name__ == "__main__":
    # Problem size (edit as needed): how many planes to schedule and how
    # many training iterations to run.
    plane_num, iteration_num = 20, 50

    # Build the environment and collect the FCFS baseline records returned
    # by ``Record``; only the per-iteration series is plotted below.
    env = Plane(plane_num)
    line_fcfs_record, fcfs_record, fc_sum = env.Record(plane_num, iteration_num)

    # Build the SARSA agent over every action index the environment exposes.
    # ``env`` and ``RL`` stay module-level names because ``update`` looks
    # them up as globals.
    RL = SarsaTable(actions=[*range(env.n_actions)])

    # Train, keeping the delay value recorded for each iteration.
    line_rl_record = update(plane_num, iteration_num)
    print(line_rl_record)

    # Plot the RL curve against the FCFS baseline.
    for series, series_label, series_colour in (
        (line_rl_record, "RL", "blue"),
        (line_fcfs_record, "FCFS", "red"),
    ):
        plt.plot(series, label=series_label, marker="o",
                 color=series_colour, linestyle="-")
    plt.xlabel("Number of Iteration")
    plt.ylabel("Delay Time")
    plt.legend()
    plt.show()

